Practical Federated Learning with TensorFlow and TensorFlowJS— Part 1
TL;DR: the full code is on GitHub (https://github.com/aniketbiprojit/Federated-Learning-with-TensorFlow).
Overview
Over the past two decades, Artificial Intelligence (AI) and Machine Learning (ML) have made great leaps in progress. This has been made possible by robust advances in semiconductor technology and the falling price of GPUs for neural network computation. With mobile companies like Apple and Samsung advertising their prowess in this area, Artificial Intelligence and Machine Learning have become terms that even the general populace is familiar with.
One of the foundational necessities of Machine Learning is data. At its most basic level, Machine Learning is about finding a function that maps inputs to outputs through an iterative process.
But the abundance of data-gathering practices has given rise to concerns such as data privacy and data leakage. Consumer awareness has led to reforms such as the General Data Protection Regulation (GDPR). Yet more than a few tech companies, including tech giants like Facebook and Google, have come under scrutiny because of data privacy laws.
Another problem with traditional training is that companies without the resources to acquire and store data in server farms either lack data or have poor-quality data, which can lead to inadequately trained models.
Federated Learning is a novel approach proposed to solve this challenge: keeping data private and secure while still training models.
This approach uses an averaging algorithm: models are trained on our devices, and only the model weights (not the raw data) are uploaded to the cloud server.
There, the weights are averaged with those of other users, which makes it harder to analyse any single user and reduces the chance of data leakage.
The idea has been around for a while, but only recently, with the advent of more powerful chipsets in mobile devices, has this approach become practical.
Federated Learning is the decentralisation of the training process of machine learning models. Because updates have to travel between many devices and a server, communication is much more expensive than in a classical setting where all training happens inside a single data centre. Google describes it as ‘Collaborative Machine Learning without Centralized Training Data’.
Federated Learning can also be seen as a major step towards the democratization of AI.
Federated Learning revolves around these four steps:
- Select a small subset of client devices, which download the current model.
- This subset trains the model on data either generated by the client or provided to the client.
- The model updates are sent to the server.
- The server uses an averaging algorithm to combine the updates sent by the clients.
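The four steps can be sketched as a toy, pure-Python simulation. This is only an illustration built on assumptions that are not in the original: the "model" is a one-variable linear regression y = w*x + b rather than a neural network, the client data is generated from y = 3x + 1 without noise, and make_clients, local_train and federated_round are hypothetical names.

```python
import random


def make_clients(num_clients=8, points_per_client=5):
    """Hypothetical client datasets drawn from y = 3x + 1 (noise-free).

    Client i gets x values spread across [0, 1), interleaved with the other clients.
    """
    total = num_clients * points_per_client
    clients = []
    for i in range(num_clients):
        xs = [(j * num_clients + i) / total for j in range(points_per_client)]
        clients.append([(x, 3 * x + 1) for x in xs])
    return clients


def local_train(w, b, data, lr=0.1, epochs=20):
    """Step 2: a client runs plain gradient descent on mean squared error."""
    n = len(data)
    for _ in range(epochs):
        grad_w = sum(2 * (w * x + b - y) * x for x, y in data) / n
        grad_b = sum(2 * (w * x + b - y) for x, y in data) / n
        w, b = w - lr * grad_w, b - lr * grad_b
    return w, b


def federated_round(w, b, clients, fraction=0.5, rng=random):
    """One round: select clients, train locally, send weights, average."""
    # Step 1: select a small subset of clients, which receive the current model.
    k = max(1, int(len(clients) * fraction))
    selected = rng.sample(clients, k)
    # Steps 2-3: each selected client trains on its own data and returns weights only.
    updates = [local_train(w, b, data) for data in selected]
    # Step 4: the server averages the returned weights (plain, unweighted mean).
    new_w = sum(u[0] for u in updates) / k
    new_b = sum(u[1] for u in updates) / k
    return new_w, new_b


if __name__ == "__main__":
    rng = random.Random(0)
    clients = make_clients()
    w, b = 0.0, 0.0
    for _ in range(30):
        w, b = federated_round(w, b, clients, rng=rng)
    print(f"w={w:.3f}, b={b:.3f}")  # close to w=3, b=1
```

The server never sees any (x, y) pair, only the weights returned by the selected clients. Because all eight toy clients hold the same number of points, the plain mean gives every client equal weight.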
Communication cost increases drastically with model size and is one of the known bottlenecks of federated learning. Structured updates are one way to tackle this problem: instead of an arbitrary full update, each client learns an update restricted to a compact, parameterised form (for example, a low-rank matrix), so far fewer values have to be sent to the server.
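To make the saving concrete, here is a small sketch of the bookkeeping behind one kind of structured update, a low-rank update. The client sends two thin factors, A (d1 × k) and B (k × d2), instead of the full d1 × d2 update matrix, and the server multiplies them back together. The 784 × 128 layer size and rank 8 are hypothetical numbers chosen for illustration, and the function names are our own.

```python
def floats_sent(d1, d2, rank=None):
    """Number of values uploaded for a d1 x d2 weight update.

    rank=None means the full matrix is sent; otherwise the client sends
    factors A (d1 x rank) and B (rank x d2).
    """
    if rank is None:
        return d1 * d2
    return d1 * rank + rank * d2


def reconstruct(a, b):
    """Server side: rebuild the full update H = A @ B from the two factors."""
    rank = len(b)
    cols = len(b[0])
    return [[sum(row[r] * b[r][j] for r in range(rank)) for j in range(cols)] for row in a]


if __name__ == "__main__":
    full = floats_sent(784, 128)              # 100352 values
    low_rank = floats_sent(784, 128, rank=8)  # 7296 values
    print(full, low_rank, full / low_rank)
    print(reconstruct([[1.0], [2.0]], [[3.0, 4.0]]))  # [[3.0, 4.0], [6.0, 8.0]]
```

A real structured update is trained directly in this restricted form on the client; the sketch only shows the size arithmetic and the server-side reconstruction.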
Google's Gboard keyboard on Android already uses this technique to update its ML models, running a miniature version of TensorFlow on the device. It also uses a Secure Aggregation protocol, under which the server can decrypt only the combined update, and only after at least 100 updates have been received.
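The core idea that lets a server learn only the aggregate can be illustrated with pairwise additive masking, one of the building blocks used in secure aggregation protocols. This toy sketch is a simplification and not Google's protocol: it assumes every pair of clients already shares a random seed, works on integer-encoded updates modulo a fixed number, and ignores key agreement, client dropouts and the minimum-participant threshold. All names are hypothetical.

```python
import random

MODULUS = 2**32  # masked values live in the integers modulo 2**32


def pair_mask(seed_pairs, i, j, length):
    """Shared mask for clients i < j, derived from the seed that only they know."""
    rng = random.Random(seed_pairs[(i, j)])
    return [rng.randrange(MODULUS) for _ in range(length)]


def mask_update(client, update, num_clients, seed_pairs):
    """Client side: add masks shared with higher-numbered peers, subtract lower ones."""
    masked = list(update)
    for other in range(num_clients):
        if other == client:
            continue
        lo, hi = min(client, other), max(client, other)
        mask = pair_mask(seed_pairs, lo, hi, len(update))
        sign = 1 if client == lo else -1
        masked = [(m + sign * r) % MODULUS for m, r in zip(masked, mask)]
    return masked


def aggregate(masked_updates):
    """Server side: sum the masked vectors; each pairwise mask is added once and subtracted once."""
    return [sum(col) % MODULUS for col in zip(*masked_updates)]


if __name__ == "__main__":
    updates = [[5, 1, 7], [2, 2, 2], [10, 0, 3]]  # integer-encoded client updates
    n = len(updates)
    setup = random.Random(42)
    seeds = {(i, j): setup.randrange(2**63) for i in range(n) for j in range(i + 1, n)}
    masked = [mask_update(c, u, n, seeds) for c, u in enumerate(updates)]
    print(masked[0])          # looks random to the server
    print(aggregate(masked))  # [17, 3, 12], the true sum
```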
Implementation
Instead of creating a real distributed network, we will use a single model which is distributed to all the simulated clients.
- The dataset used is MNIST.
- We will create the initial model in Python.
- Using tensorflowjs_converter, we will convert model.h5 to model.json for tfjs.
- We will simulate clients using multiple CPUs, with each CPU acting as a client.
- Each client is given a subset of the data (we do not enforce an i.i.d. split) and the server model; it trains on that subset and saves its weights in model.json.
- Using tensorflowjs_converter, we convert each model.json back to an h5 file.
- We will average all the models and update the server model, which can again be distributed to clients.
Now let’s look at an implementation of federated learning in TensorFlow.
Dependencies — GitHub
git clone https://github.com/aniketbiprojit/Federated-Learning-with-TensorFlow federated
cd federated
pip install -r requirements.txt
npm i
Note: This implementation is only meant to show how federated learning works; it is not an attempt to build a scalable, deployable model. We will look at a more detailed model in part 2 of this series.
Python Model
First, we create a simple Keras model and train it on the MNIST dataset.
This is a very simple model. As model complexity increases, communication becomes more expensive, because every client has to upload all of its updated weights. This is an important parameter to keep in mind when designing federated learning models.
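Because every client uploads all of its weights after training, the size of each update grows directly with the number of parameters. The sketch below counts the parameters of fully connected (Dense) layers and the resulting float32 upload size. The layer sizes are hypothetical MNIST models used only for illustration, not necessarily the model in the repository.

```python
def dense_params(layer_sizes):
    """Parameters of a stack of Dense layers: kernel (n_in x n_out) plus bias (n_out)."""
    return sum(n_in * n_out + n_out for n_in, n_out in zip(layer_sizes, layer_sizes[1:]))


def upload_bytes(num_params, bytes_per_param=4):
    """Bytes sent by one client per round if every weight is uploaded as float32."""
    return num_params * bytes_per_param


if __name__ == "__main__":
    for sizes in ([784, 10], [784, 128, 10], [784, 512, 512, 10]):
        p = dense_params(sizes)
        print(sizes, p, "params,", upload_bytes(p) / 1024, "KiB per client per round")
```

Multiplying the per-client figure by the number of clients and rounds gives the total traffic, which is why model size matters much more here than in centralised training.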
We will be using TensorFlowJS to iteratively train our model on client devices.
To do this, we need to convert our model back and forth between Keras HDF5 (.h5) and model.json.
We need the tensorflowjs Python package, which provides tensorflowjs_converter, for this. It is recommended to install it in a separate environment so that it does not interfere with our other dependencies.
I will be using small scripts for converting models in each direction.
The first thing we do is convert our saved model to a tfjs-readable format with to_tfjs.sh.
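The contents of to_tfjs.sh are not shown here, so the following is only a sketch of how the two conversions could be driven from Python. It assumes the tensorflowjs pip package is installed, since that provides the tensorflowjs_converter command; the helper names and file paths are hypothetical. The flags are the converter's --input_format and --output_format options.

```python
import subprocess


def keras_to_tfjs_cmd(h5_path, out_dir):
    """Command that converts a Keras .h5 file into a tfjs Layers model (model.json + weight shards)."""
    return ["tensorflowjs_converter", "--input_format=keras", h5_path, out_dir]


def tfjs_to_keras_cmd(model_json_path, h5_path):
    """Command that converts a tfjs Layers model back into a Keras .h5 file."""
    return [
        "tensorflowjs_converter",
        "--input_format=tfjs_layers_model",
        "--output_format=keras",
        model_json_path,
        h5_path,
    ]


def run(cmd):
    """Run a conversion, raising CalledProcessError if the converter fails."""
    subprocess.run(cmd, check=True)


if __name__ == "__main__":
    # Hypothetical paths: server model out to tfjs, one client model back to Keras.
    print(" ".join(keras_to_tfjs_cmd("model.h5", "tfjs-model")))
    print(" ".join(tfjs_to_keras_cmd("tfjs-models/client-0/model.json", "client-0.h5")))
```

Calling run(...) on these lists performs the actual conversion; the example only prints the commands so it works without the converter installed.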
Data Loading
Next, we set up our Node environment for tfjs models. This step can easily be converted to a browser-based method if the inputs are served over a REST API.
We won’t be using browsers for this implementation (saving and loading models there is too much hassle); instead, we simulate our clients with Node using multiple CPUs, each CPU representing one client.
Our friends at Stack Overflow helped us quickly load the MNIST data into our Node environment. We then run some pre-processing on the loaded data.
The loader function returns an array of objects (dicts), each with the label as the key and the image array as the value.
Layers API model
For our node environment, we will be using @tensorflow/tfjs-node to load our models.
As we are using tfjs-node, we can load our models directly from the native file system.
Pre-processing Data
Now we have our data and model in place, but we need to convert our data into a format that suits the tfjs model.
The current data format looks like [{‘5’: […]}, {‘3’: […]}, …], which is inconvenient for a tfjs model to read.
We simply loop through our data and convert it into tensors. This gives us an X_tensor and a y_tensor for training our model.
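In Python terms, the conversion amounts to the following sketch. The article does it in JavaScript with tfjs tensors; this version uses plain lists so the logic is easy to follow. The one-hot encoding of labels and the division by 255 are assumptions about the preprocessing, not taken from the original code, and to_xy is a hypothetical name.

```python
def to_xy(records, num_classes=10, scale=255.0):
    """Turn [{'5': pixels}, {'3': pixels}, ...] into feature rows and one-hot labels.

    Assumes each record has exactly one key (the digit label as a string) and that
    pixels are raw 0-255 values; dividing by `scale` maps them to [0, 1].
    """
    xs, ys = [], []
    for record in records:
        (label, pixels), = record.items()  # raises ValueError if not exactly one key
        xs.append([p / scale for p in pixels])
        one_hot = [0.0] * num_classes
        one_hot[int(label)] = 1.0
        ys.append(one_hot)
    return xs, ys


if __name__ == "__main__":
    xs, ys = to_xy([{"5": [0, 255]}, {"3": [51, 102]}])
    print(xs)  # [[0.0, 1.0], [0.2, 0.4]]
    print(ys[0].index(1.0), ys[1].index(1.0))  # 5 3
```

With flattened 28×28 images, xs and ys have the shapes [numExamples, 784] and [numExamples, 10], which is what a 2-D feature tensor and a one-hot label tensor would need.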
Train Model
If you are familiar with Keras, the tfjs training code should feel right at home.
Clients
Let’s go ahead and simulate our clients with Node's cluster module.
This simulates 8 clients, each training the model on a different set of data. We loaded 1024*8 images in our data loader and split them into 8 sets of 1024 images (0–1023, 1024–2047 and so on); each client then trains the model on its set for 10 epochs.
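The way the 1024*8 images are divided among clients can be written as a small helper. This Python sketch covers only the indexing (the clients themselves run under Node's cluster module); shard_ranges and client_data are hypothetical names.

```python
def shard_ranges(num_items, num_clients):
    """Contiguous [start, end) index ranges, one equal-sized block per client."""
    if num_items % num_clients:
        raise ValueError("num_items must divide evenly among clients")
    size = num_items // num_clients
    return [(c * size, (c + 1) * size) for c in range(num_clients)]


def client_data(records, client_id, num_clients):
    """The slice of the loaded records that a given client trains on."""
    start, end = shard_ranges(len(records), num_clients)[client_id]
    return records[start:end]


if __name__ == "__main__":
    for cid, (start, end) in enumerate(shard_ranges(1024 * 8, 8)):
        print(f"client {cid}: images {start}-{end - 1}")  # client 0: images 0-1023, ...
```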
We now have 8 models, trained on our clients, in the tfjs-models folder.
Back to Keras
Using our script we are going to convert our tfjs models to Keras models.
Once we have all the models in h5, we can use the weights for averaging:
Averaging Process
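For the sake of simplicity, we have used tf.reduce_mean to calculate the average weights. Given a list of same-shaped arrays, TensorFlow stacks them into a single tensor, and reducing along axis 0 returns their element-wise average across the stacked models.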
updated_weights = [tf.reduce_mean(weights[0], 0), tf.reduce_mean(weights[1], 0)]
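The tf.reduce_mean call averages two weight arrays. The same idea extends to every layer of a deeper model. The pure-Python sketch below does this and also shows the optional per-client weighting by number of training examples used in FedAvg. With equal-sized client datasets, as in our 8 × 1024 split, the weighted and plain means coincide. The function name and the nested-list representation of weights are assumptions for illustration.

```python
def average_weights(client_weights, num_examples=None):
    """Average a list of models layer by layer, element by element.

    client_weights[c][l] is weight array l (e.g. a kernel or a bias, as nested
    lists) of client c. If num_examples is given, client c is weighted by
    num_examples[c] / sum(num_examples); otherwise all clients count equally.
    """
    n = len(client_weights)
    if num_examples is None:
        num_examples = [1] * n
    total = sum(num_examples)
    coeffs = [k / total for k in num_examples]

    def combine(arrays):
        # Recurse through nested lists; at the leaves take the weighted sum of scalars.
        if isinstance(arrays[0], list):
            return [combine(list(parts)) for parts in zip(*arrays)]
        return sum(c * a for c, a in zip(coeffs, arrays))

    num_layers = len(client_weights[0])
    return [combine([model[l] for model in client_weights]) for l in range(num_layers)]


if __name__ == "__main__":
    a = [[[1.0, 2.0], [3.0, 4.0]], [0.0, 1.0]]  # client A: 2x2 kernel, bias of length 2
    b = [[[3.0, 4.0], [5.0, 6.0]], [2.0, 3.0]]  # client B: same shapes
    print(average_weights([a, b]))  # [[[2.0, 3.0], [4.0, 5.0]], [1.0, 2.0]]
```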
The averaged model can then be sent to the rest of the clients for further training.
Challenges
One of the major pitfalls of this implementation is that we did nothing to guarantee that the client data is i.i.d.: each client simply gets its own slice of the data, and no step checks that its mix of digits matches that of the full dataset. This might bias our trained model. Non-i.i.d. data is a real challenge in Federated Learning. We could easily have manipulated the data to fit our needs, but we skipped that step to reflect real-world conditions.
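A simple way to check how far a split is from i.i.d. is to compare each client's label distribution with the global one. The sketch below uses total variation distance for this; the choice of metric and the function names are our own, not part of the original implementation. A value near 0 means a client sees roughly the same mix of digits as the whole dataset. With ten balanced classes, a client holding only one digit scores 0.9.

```python
from collections import Counter


def label_distribution(labels, num_classes=10):
    """Fraction of examples with each label 0..num_classes-1."""
    counts = Counter(labels)
    total = len(labels)
    return [counts.get(k, 0) / total for k in range(num_classes)]


def total_variation(p, q):
    """Total variation distance between two distributions: 0 = identical, 1 = disjoint."""
    return 0.5 * sum(abs(a - b) for a, b in zip(p, q))


def skew_per_client(labels, num_clients, num_classes=10):
    """How far each contiguous client shard's label mix is from the global one."""
    size = len(labels) // num_clients
    global_dist = label_distribution(labels, num_classes)
    return [
        total_variation(label_distribution(labels[c * size:(c + 1) * size], num_classes), global_dist)
        for c in range(num_clients)
    ]


if __name__ == "__main__":
    sorted_labels = [d for d in range(10) for _ in range(100)]  # worst case: sorted by digit
    print([round(s, 2) for s in skew_per_client(sorted_labels, 10)])  # 0.9 for every client
```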
Federated Learning also introduces the risk of model poisoning. An attacker could train a model that stays just as accurate on the main task but behaves maliciously on inputs with one specific feature (a backdoor), and send those updates to the server.
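With plain averaging, a single attacker can have outsized influence. The arithmetic below illustrates the idea behind the "model replacement" attack described in How To Backdoor Federated Learning. If the other clients' models barely change, an attacker who scales its deviation from the global model by the number of participants can move the average almost exactly onto its own model. This is a toy sketch with made-up two-element "models" and hypothetical function names, not the paper's full attack.

```python
def mean(vectors):
    """Server aggregation: plain element-wise mean of the client models."""
    return [sum(col) / len(vectors) for col in zip(*vectors)]


def scaled_malicious_model(global_model, target_model, num_clients):
    """Attacker's submission chosen so that the plain mean lands near target_model.

    Assumes the honest clients return models close to global_model.
    """
    return [g + num_clients * (t - g) for g, t in zip(global_model, target_model)]


if __name__ == "__main__":
    global_model = [1.0, 1.0]
    target = [5.0, -3.0]          # hypothetical backdoored weights
    honest = [[1.0, 1.0]] * 9     # 9 honest clients whose models (for simplicity) did not change
    attacker = scaled_malicious_model(global_model, target, num_clients=10)
    print(mean(honest + [attacker]))  # [5.0, -3.0]
```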
Further Reading
Federated Learning: Collaborative Machine Learning without Centralized Training Data
Federated Machine Learning: Concept and Applications
How To Backdoor Federated Learning
Federated Learning for Image Classification
More Technologies and Frameworks for Federated Learning:
- PySyft was created for secure and private Deep Learning. It decouples private data from model training using Federated Learning, Differential Privacy and Encrypted Computation (such as Multi-Party Computation (MPC) and Homomorphic Encryption (HE)) within the main Deep Learning frameworks like PyTorch and TensorFlow.
- Federated Learning Algorithms In FATE
- TensorFlow Federated is an open-source framework for machine learning and other computations on decentralized data.
Here, we haven’t used Federated Averaging (FedAvg); we simply took a plain mean of the client weights. In the next part of this series, we will use the frameworks mentioned above and FedAvg to create a complete end-to-end Federated Learning model.